package com.snailwu.ratelimiter;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.ReturnType;

/**
 * Distributed rate limiter whose state lives in a single Redis hash and is updated atomically by a
 * Lua script.
 *
 * <p>The hash stored under {@code key} holds two fields:
 * <ul>
 *   <li>{@code remain}: the number of tokens still available;</li>
 *   <li>{@code latest}: the interval-aligned timestamp (ms) of the last successful acquisition.</li>
 * </ul>
 *
 * <p>The current time is taken from the client ({@link System#currentTimeMillis()}), so all clients
 * sharing a key should have reasonably synchronized clocks.
 *
 * @author WuQinglong
 */
public class RedisRateLimiter {

    private static final Logger log = LoggerFactory.getLogger(RedisRateLimiter.class);

    /**
     * Token-bucket script evaluated atomically by Redis.
     *
     * <p>The script takes KEYS[1], the limiter hash, and three arguments:
     * <ul>
     *   <li>ARGV[1]: the token cap;</li>
     *   <li>ARGV[2]: the milliseconds needed to generate one token;</li>
     *   <li>ARGV[3]: the current time in milliseconds.</li>
     * </ul>
     * It returns 1 when a token was taken and 0 otherwise.
     *
     * <p>Behavior:
     * <ul>
     *   <li>First call (no {@code latest} field): succeeds and leaves the bucket empty.</li>
     *   <li>More than 1000 ms since {@code latest}: the bucket is reset to a single token, i.e. treated
     *       like a first call instead of being refilled to the cap.</li>
     *   <li>Otherwise: one token is added per elapsed interval, capped at ARGV[1].</li>
     * </ul>
     *
     * <p>NOTE(review): the 1000 ms reset threshold is hard-coded and independent of the configured
     * interval. Confirm that this is intended.
     *
     * <p>The multi-field {@code HSET} requires Redis 4.0 or later.
     */
    private static final String LUA_SCRIPT =
        "local max=tonumber(ARGV[1]) " + // maximum number of tokens per unit of time
            "local interval=tonumber(ARGV[2]) " + // milliseconds needed to generate one token
            "local align=tonumber(ARGV[3]) " + // current timestamp (aligned on the next line)
            "align=align-(align%interval) " + // round the timestamp down to an interval boundary
            "local latest=redis.call('HGET',KEYS[1],'latest') " + // time of the last successful acquisition
            "if latest then " +
            "   local remain=tonumber(redis.call('HGET',KEYS[1],'remain')) " +
            "   local diff=align-tonumber(latest) " +
            "   if diff>1000 then " +
            "       remain=1 " +
            "   elseif diff>=interval then " +
            "       remain=math.min((math.floor(diff/interval)+remain),max) " +
            "   end " +
            "   if remain<=0 then " +
            "       return 0 " +
            "   else " +
            "       redis.call('HSET',KEYS[1],'remain',remain-1,'latest',align) " +
            "       return 1 " +
            "   end " +
            "else " + // first acquisition: succeed immediately and record the aligned time
            "    redis.call('HSET',KEYS[1],'remain',0,'latest',align) " +
            "    return 1 " +
            "end";

    /** UTF-8 encoding of {@link #LUA_SCRIPT}, computed once instead of on every call. */
    private static final byte[] LUA_SCRIPT_BYTES = LUA_SCRIPT.getBytes(StandardCharsets.UTF_8);

    /** Redis key of the hash that holds the limiter state. */
    private final String key;
    /** Token cap, pre-rendered as the ARGV[1] string. */
    private final String max;
    /** Milliseconds per generated token, pre-rendered as the ARGV[2] string. */
    private final String interval;
    private final RedisConnectionFactory connectionFactory;

    /**
     * Creates a limiter that allows about {@code max} acquisitions per second.
     *
     * <p>One token is generated every {@code 1000 / max} milliseconds, clamped to at least 1 ms.
     * Because of that clamp, a {@code max} above 1000 is effectively limited to a sustained rate of
     * 1000 per second.
     *
     * @param key               Redis key used to store the limiter state; it is deleted on construction
     * @param max               maximum number of tokens per second; must be positive
     * @param connectionFactory factory used to obtain Redis connections
     * @throws IllegalArgumentException if {@code max} is not positive
     * @throws NullPointerException     if {@code key} or {@code connectionFactory} is null
     */
    public RedisRateLimiter(String key, int max, RedisConnectionFactory connectionFactory) {
        this(key, max, defaultInterval(max), connectionFactory);
    }

    /**
     * Creates a limiter with an explicit token generation interval.
     *
     * <p>Side effect: any existing value under {@code key} is deleted. That resets the limiter state
     * for every other client currently sharing the same key.
     *
     * @param key               Redis key used to store the limiter state; it is deleted on construction
     * @param max               maximum number of tokens that can accumulate; must be positive
     * @param interval          milliseconds needed to generate one token; must be positive
     * @param connectionFactory factory used to obtain Redis connections
     * @throws IllegalArgumentException if {@code max} or {@code interval} is not positive
     * @throws NullPointerException     if {@code key} or {@code connectionFactory} is null
     */
    public RedisRateLimiter(String key, int max, int interval, RedisConnectionFactory connectionFactory) {
        if (max <= 0) {
            throw new IllegalArgumentException("max must be positive: " + max);
        }
        // A zero interval makes the script compute "align % 0" (NaN) and break the limiter.
        if (interval <= 0) {
            throw new IllegalArgumentException("interval must be positive: " + interval);
        }
        this.key = Objects.requireNonNull(key, "key");
        this.max = String.valueOf(max);
        this.interval = String.valueOf(interval);
        this.connectionFactory = Objects.requireNonNull(connectionFactory, "connectionFactory");
        // Close the connection after the reset; the original code leaked it.
        try (RedisConnection connection = connectionFactory.getConnection()) {
            connection.del(key.getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Computes the default token interval (ms) for a per-second limit.
     *
     * <p>Integer division yields 0 when {@code max > 1000}, which the script cannot handle, so the
     * result is clamped to at least 1 ms.
     *
     * @throws IllegalArgumentException if {@code max} is not positive (instead of an ArithmeticException)
     */
    private static int defaultInterval(int max) {
        if (max <= 0) {
            throw new IllegalArgumentException("max must be positive: " + max);
        }
        return Math.max(1, 1000 / max);
    }

    /**
     * Tries to acquire a token, retrying with a fixed pause between attempts.
     *
     * <p>At most {@code retryNum + 1} attempts are made; a negative {@code retryNum} makes no attempt.
     * If the calling thread is interrupted while waiting, its interrupt status is restored and
     * {@code false} is returned immediately.
     *
     * @param retryNum  number of retries after the first attempt
     * @param sleepTime pause between attempts, in milliseconds
     * @return {@code true} if a token was acquired, {@code false} otherwise
     */
    public boolean tryAcquire(int retryNum, long sleepTime) {
        for (int i = 0; i <= retryNum; i++) {
            if (tryAcquire()) {
                return true;
            }
            if (i == retryNum) {
                log.warn("【触发限流】尝试{}次后仍旧获取不到token。key:{}", retryNum, key);
                return false;
            }
            log.info("获取Token失败，休眠:{}重试。key:{}", sleepTime, key);
            try {
                TimeUnit.MILLISECONDS.sleep(sleepTime);
            } catch (InterruptedException e) {
                // Honor the caller's cancellation: restore the flag and stop retrying. Continuing would
                // make every later sleep fail immediately and silently drop the interrupt.
                Thread.currentThread().interrupt();
                return false;
            }
        }
        return false;
    }

    /**
     * Tries to acquire one token without waiting.
     *
     * <p>Any exception raised while talking to Redis is logged and treated as "no token". In other
     * words, the limiter fails closed.
     *
     * @return {@code true} if a token was acquired, {@code false} otherwise
     */
    public boolean tryAcquire() {
        List<String> keysList = Collections.singletonList(key);
        List<String> argsList = Arrays.asList(max, interval, String.valueOf(System.currentTimeMillis()));
        byte[][] keysAndArgs = getKeyArgs(keysList, argsList);
        try (RedisConnection connection = connectionFactory.getConnection()) {
            Boolean ret = connection.eval(LUA_SCRIPT_BYTES, ReturnType.BOOLEAN, keysList.size(), keysAndArgs);
            return Boolean.TRUE.equals(ret);
        } catch (Exception e) {
            log.error("获取令牌时发生异常", e);
            return false;
        }
    }

    /**
     * Encodes the keys followed by the arguments as UTF-8 byte arrays, in the layout expected by
     * {@code RedisConnection#eval} (keys first, then args).
     */
    private static byte[][] getKeyArgs(List<String> keyList, List<String> argList) {
        byte[][] keyArgs = new byte[keyList.size() + argList.size()][];
        int i = 0;
        for (String k : keyList) {
            keyArgs[i++] = k.getBytes(StandardCharsets.UTF_8);
        }
        for (String arg : argList) {
            keyArgs[i++] = arg.getBytes(StandardCharsets.UTF_8);
        }
        return keyArgs;
    }

}
